{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Synthesize data for speculative decoding training\n",
    "\n",
    "The speculative decoding module needs to learn to predict the tokens that the base model produces. Therefore, the training data must be generated by the base model itself.\n",
    "Note: if the target base model is a quantized version, the synthesized data should be generated using the quantized model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, quantize the base model (Llama-3.2-1B-Instruct) to FP8 and export it in the unified checkpoint format (`--export_fmt hf`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python llm_ptq/hf_ptq.py --pyt_ckpt_path meta-llama/Llama-3.2-1B-Instruct --qformat fp8 --batch_size 1 --export_path /tmp/llama3.2_1B_fp8 --export_fmt hf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, download the Daring-Anteater dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!git clone https://huggingface.co/datasets/nvidia/Daring-Anteater /tmp/Daring-Anteater"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, launch an inference server that runs the quantized base model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!vllm serve /tmp/llama3.2_1B_fp8 --api-key token-abc123 --port 8000  --tensor-parallel-size 1 --quantization=modelopt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Open a new terminal and adapt the fine-tuning data by calling this server.\n",
    "Note: this may take a long time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p /tmp/finetune\n",
    "!bash prepare_data.sh --data_path /tmp/Daring-Anteater/train.jsonl --output_path /tmp/finetune/data.jsonl --max_token 2048"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's load the base model and convert it to an EAGLE model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformers\n",
    "\n",
    "import modelopt.torch.opt as mto\n",
    "import modelopt.torch.speculative as mtsp\n",
    "\n",
    "mto.enable_huggingface_checkpointing()\n",
    "\n",
    "model = transformers.AutoModelForCausalLM.from_pretrained(\n",
    "    \"meta-llama/Llama-3.2-1B-Instruct\", torch_dtype=\"auto\"\n",
    ")\n",
    "config = {\n",
    "    \"eagle_num_layers\": 1,\n",
    "    \"use_input_layernorm_in_first_layer\": True,\n",
    "    \"use_last_layernorm\": False,\n",
    "}\n",
    "mtsp.convert(model, [(\"eagle\", config)])\n",
    "\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(\"meta-llama/Llama-3.2-1B-Instruct\")\n",
    "tokenizer.pad_token_id = tokenizer.eos_token_id"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the synthesized data is ready, we can start training the EAGLE model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from dataclasses import dataclass, field\n",
    "\n",
    "from speculative_decoding.eagle_utils import DataCollatorWithPadding, LazySupervisedDataset\n",
    "from transformers import Trainer\n",
    "\n",
    "with open(\"/tmp/finetune/data.jsonl\") as f:\n",
    "    data_json = [json.loads(line) for line in f]\n",
    "train_dataset = LazySupervisedDataset(data_json[: int(len(data_json) * 0.95)], tokenizer=tokenizer)\n",
    "eval_dataset = LazySupervisedDataset(data_json[int(len(data_json) * 0.95) :], tokenizer=tokenizer)\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class TrainingArguments(transformers.TrainingArguments):\n",
    "    cache_dir: str | None = field(default=None)\n",
    "    model_max_length: int = field(\n",
    "        default=4096,\n",
    "        metadata={\n",
    "            \"help\": (\n",
    "                \"Maximum sequence length. Sequences will be right padded (and possibly truncated).\"\n",
    "            )\n",
    "        },\n",
    "    )\n",
    "    dataloader_drop_last: bool = field(default=True)\n",
    "    bf16: bool = field(default=True)\n",
    "\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"/tmp/eagle_bf16\",\n",
    "    num_train_epochs=1.0,\n",
    "    per_device_train_batch_size=1,\n",
    "    per_device_eval_batch_size=1,\n",
    ")\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=eval_dataset,\n",
    "    data_collator=DataCollatorWithPadding(),\n",
    ")\n",
    "trainer._move_model_to_device(model, trainer.args.device)\n",
    "\n",
    "# Manually enable this to return loss in eval\n",
    "trainer.can_return_loss = True\n",
    "# Make sure label_smoother is None\n",
    "assert trainer.label_smoother is None, \"label_smoother is not supported in speculative decoding!\"\n",
    "\n",
    "trainer.train()\n",
    "trainer.save_state()\n",
    "trainer.save_model(training_args.output_dir)\n",
    "tokenizer.save_pretrained(training_args.output_dir)\n",
    "\n",
    "metrics = trainer.evaluate()\n",
    "print(f\"Evaluation results: \\n{metrics}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we have an EAGLE model in BF16 format. Next, we quantize this model to FP8 with post-training quantization (PTQ)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import modelopt.torch.quantization as mtq\n",
    "import modelopt.torch.utils.dataset_utils as dataset_utils\n",
    "\n",
    "mto.enable_huggingface_checkpointing()\n",
    "\n",
    "model = transformers.AutoModelForCausalLM.from_pretrained(\"/tmp/eagle_bf16\")\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(\"/tmp/eagle_bf16\")\n",
    "\n",
    "calib_dataloader = dataset_utils.get_dataset_dataloader(\n",
    "    dataset_name=\"cnn_dailymail\",\n",
    "    tokenizer=tokenizer,\n",
    "    batch_size=1,\n",
    "    num_samples=512,\n",
    "    device=model.device,\n",
    "    include_labels=False,\n",
    ")\n",
    "\n",
    "quant_cfg = mtq.FP8_DEFAULT_CFG\n",
    "quant_cfg[\"quant_cfg\"][\"*output_quantizer\"] = {\n",
    "    \"num_bits\": (4, 3),\n",
    "    \"axis\": None,\n",
    "    \"enable\": True,\n",
    "}\n",
    "\n",
    "# The first positional parameter of create_forward_loop is the model, so pass the dataloader by keyword only.\n",
    "calibrate_loop = dataset_utils.create_forward_loop(dataloader=calib_dataloader)\n",
    "model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)\n",
    "mtq.print_quant_summary(model)\n",
    "\n",
    "model.save_pretrained(\"/tmp/eagle_fp8_ptq\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To recover the accuracy lost to quantization, we fine-tune the quantized model with quantization-aware training (QAT)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_args.output_dir = \"/tmp/eagle_fp8_qat\"\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=eval_dataset,\n",
    "    data_collator=DataCollatorWithPadding(),\n",
    ")\n",
    "trainer._move_model_to_device(model, trainer.args.device)\n",
    "\n",
    "# Manually enable this to return loss in eval\n",
    "trainer.can_return_loss = True\n",
    "# Make sure label_smoother is None\n",
    "assert trainer.label_smoother is None, \"label_smoother is not supported in speculative decoding!\"\n",
    "\n",
    "trainer.train()\n",
    "trainer.save_state()\n",
    "trainer.save_model(training_args.output_dir)\n",
    "tokenizer.save_pretrained(training_args.output_dir)\n",
    "\n",
    "metrics = trainer.evaluate()\n",
    "print(f\"Evaluation results: \\n{metrics}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To deploy this model, we first need to export it to a unified checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from accelerate.hooks import remove_hook_from_module\n",
    "\n",
    "from modelopt.torch.export import export_hf_checkpoint\n",
    "\n",
    "# Move meta tensor back to device before exporting.\n",
    "remove_hook_from_module(model, recurse=True)\n",
    "\n",
    "export_hf_checkpoint(\n",
    "    model,\n",
    "    export_dir=\"/tmp/hf_ckpt\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then convert the unified checkpoint to a TensorRT-LLM checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python TensorRT-LLM/examples/eagle/convert_checkpoint.py --model_dir /tmp/hf_ckpt --output_dir /tmp/trtllm_ckpt --num_eagle_layers 5 --max_non_leaves_per_layer 4 --max_draft_len 25 --dtype float16"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, build a TensorRT-LLM engine."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!trtllm-build --checkpoint_dir /tmp/trtllm_ckpt --output_dir /tmp/trtllm_engine --gemm_plugin float16 --use_paged_context_fmha enable --speculative_decoding_mode eagle  --max_batch_size 4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To run the EAGLE engine, please refer to [TensorRT-LLM/examples/eagle](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/eagle):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python ../run.py --engine_dir /tmp/trtllm_engine \\\n",
    "                  --tokenizer_dir /tmp/eagle_fp8_qat \\\n",
    "                  --max_output_len=100 \\\n",
    "                  --eagle_choices=\"[[0],[1],[2],[3],[0,0],[0,1],[0,2],[1,0],[1,1],[2,0],[2,1],[3,0],[0,0,0],[0,0,1],[0,0,2],[0,1,0],[0,1,1],[0,2,0],[0,2,1],[1,0,0],[0,0,0,0],[0,0,0,1],[0,0,0,2],[0,0,0,0,0],[0,0,0,0,1]]\" \\\n",
    "                  --temperature 1.0 \\\n",
    "                  --input_text \"Once upon\""
   ]
  }
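  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sanity check, the tree passed to `--eagle_choices` can be compared with the flags given to `convert_checkpoint.py` earlier. The sketch below assumes that each entry of `--eagle_choices` is a path from the root of the draft tree, so that the number of entries corresponds to `--max_draft_len 25`, the longest path to `--num_eagle_layers 5`, and the largest number of nodes with children at any depth to `--max_non_leaves_per_layer 4`. This reading of the flags is an assumption, not something stated in this notebook, and the variable names are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "# Same tree as the --eagle_choices argument in the run command above.\n",
    "eagle_choices = [\n",
    "    [0], [1], [2], [3],\n",
    "    [0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [2, 0], [2, 1], [3, 0],\n",
    "    [0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 0], [0, 1, 1], [0, 2, 0], [0, 2, 1], [1, 0, 0],\n",
    "    [0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 0, 2],\n",
    "    [0, 0, 0, 0, 0], [0, 0, 0, 0, 1],\n",
    "]\n",
    "paths = {tuple(p) for p in eagle_choices}\n",
    "\n",
    "# Every node deeper than the first level must have its parent listed as well.\n",
    "assert all(p[:-1] in paths for p in paths if len(p) > 1)\n",
    "\n",
    "num_draft_tokens = len(paths)  # assumed to match --max_draft_len\n",
    "max_depth = max(len(p) for p in paths)  # assumed to match --num_eagle_layers\n",
    "\n",
    "# A non-leaf is any listed node that is the parent of another listed node.\n",
    "non_leaves = {p[:-1] for p in paths if len(p) > 1}\n",
    "non_leaves_per_depth = Counter(len(p) for p in non_leaves)\n",
    "max_non_leaves = max(non_leaves_per_depth.values())  # assumed to match --max_non_leaves_per_layer\n",
    "\n",
    "# For the tree above this prints: 25 5 4\n",
    "print(num_draft_tokens, max_depth, max_non_leaves)"
   ]
  }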
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py312",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
